package com.simple.rag;

import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Minimal client for Baidu's Wenxin Yiyan (ERNIE Bot) chat-completion HTTP API.
 *
 * <p>Project: middle-ware-design
 *
 * <p>Purpose: answers a question given a block of context text. Judging by the package name this
 * is presumably the generation step of a RAG pipeline; confirm against the callers.
 *
 * <p>An OAuth access token is obtained from {@code TOKEN_URL} with the API key / secret key pair.
 * It is then passed as the {@code access_token} query parameter to {@code API_URL}. All failures
 * are logged and reported to callers as {@code null} return values rather than exceptions.
 *
 * <p>NOTE(review): the mutable {@code accessToken} field is not synchronized; confirm whether
 * instances are shared across threads.
 *
 * @author: WuChengXing
 * @create: 2025-02-14 00:46
 **/
public class WenxinYiyanLLM {
    private static final Logger LOG = Logger.getLogger(WenxinYiyanLLM.class.getName());

    private static final String TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token";
    private static final String API_URL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions";

    /** Connect timeout for both endpoints, in milliseconds. */
    private static final int CONNECT_TIMEOUT_MS = 10_000;
    /** Read timeout in milliseconds; generous because producing a completion can take a while. */
    private static final int READ_TIMEOUT_MS = 120_000;

    /** Baidu error codes meaning "access token invalid" (110) and "access token expired" (111). */
    private static final int ERROR_TOKEN_INVALID = 110;
    private static final int ERROR_TOKEN_EXPIRED = 111;

    private final String apiKey;
    private final String secretKey;
    /** Current OAuth token; {@code null} when the last fetch failed, in which case it is re-fetched lazily. */
    private String accessToken;

    /**
     * Creates a client and eagerly requests an access token.
     *
     * <p>If the token request fails, the failure is logged and the token is re-requested on the
     * next {@link #generateAnswer} call.
     *
     * @param apiKey    Baidu application API key (sent as {@code client_id})
     * @param secretKey Baidu application secret key (sent as {@code client_secret})
     */
    public WenxinYiyanLLM(String apiKey, String secretKey) {
        this.apiKey = apiKey;
        this.secretKey = secretKey;
        this.accessToken = getAccessToken();
    }

    /**
     * Requests an OAuth access token using the client-credentials grant.
     *
     * @return the {@code access_token} field of the response, or {@code null} on any failure
     */
    private String getAccessToken() {
        try {
            URL url = new URL(TOKEN_URL
                    + "?grant_type=client_credentials"
                    + "&client_id=" + urlEncode(apiKey)
                    + "&client_secret=" + urlEncode(secretKey));
            HttpURLConnection connection = openConnection(url, "GET");
            JSONObject json = JSON.parseObject(readResponse(connection));
            String token = json == null ? null : json.getString("access_token");
            if (token == null) {
                // Log only the response body, never the request URL: the URL carries the secret key.
                LOG.warning("Failed to obtain access token, response: " + json);
            }
            return token;
        } catch (Exception e) {
            LOG.log(Level.WARNING, "Failed to obtain access token", e);
            return null;
        }
    }

    /**
     * Asks the model to answer {@code question} using {@code context} as background information.
     *
     * @param context  context text inserted into the prompt
     * @param question the user's question
     * @return the model's answer, or {@code null} if the token could not be obtained, the HTTP
     *         call failed, or the API returned an error / unrecognized payload (details are logged)
     */
    public String generateAnswer(String context, String question) {
        try {
            if (accessToken == null) {
                // The constructor's fetch may have failed transiently, or the token was invalidated.
                accessToken = getAccessToken();
                if (accessToken == null) {
                    return null;
                }
            }
            String prompt = "上下文信息：" + context + " 问题：" + question;
            JSONObject message = new JSONObject().fluentPut("role", "user").fluentPut("content", prompt);
            JSONObject requestBody = new JSONObject();
            // The API expects "messages" to be a JSON array of {role, content} objects, not one object.
            requestBody.put("messages", Collections.singletonList(message));

            URL url = new URL(API_URL + "?access_token=" + urlEncode(accessToken));
            HttpURLConnection connection = openConnection(url, "POST");
            connection.setRequestProperty("Content-Type", "application/json");
            connection.setDoOutput(true);
            try (OutputStream out = connection.getOutputStream()) {
                // Explicit UTF-8: the prompt contains Chinese text and the platform charset may differ.
                out.write(requestBody.toJSONString().getBytes(StandardCharsets.UTF_8));
            }
            return extractAnswer(JSON.parseObject(readResponse(connection)));
        } catch (Exception e) {
            LOG.log(Level.WARNING, "Failed to generate answer", e);
            return null;
        }
    }

    /**
     * Pulls the answer text out of a parsed chat-completion response.
     *
     * <p>Baidu's API reports failures as a JSON body carrying {@code error_code}/{@code error_msg}.
     * On success it returns the answer in a top-level {@code result} field. The OpenAI-style
     * {@code choices[0].message.content} layout assumed by the earlier implementation is still
     * accepted as a fallback.
     *
     * @return the answer text, or {@code null} (after logging) on an error or unrecognized payload
     */
    private String extractAnswer(JSONObject json) {
        if (json == null) {
            LOG.warning("Empty chat completion response");
            return null;
        }
        if (json.containsKey("error_code")) {
            Integer errorCode = json.getInteger("error_code");
            if (errorCode != null
                    && (errorCode == ERROR_TOKEN_INVALID || errorCode == ERROR_TOKEN_EXPIRED)) {
                // Force a fresh token fetch on the next call instead of failing forever.
                accessToken = null;
            }
            LOG.warning("Chat completion failed: error_code=" + errorCode
                    + ", error_msg=" + json.getString("error_msg"));
            return null;
        }
        String result = json.getString("result");
        if (result != null) {
            return result;
        }
        List<Object> choices = json.getJSONArray("choices");
        if (choices != null && !choices.isEmpty() && choices.get(0) instanceof JSONObject) {
            JSONObject message = ((JSONObject) choices.get(0)).getJSONObject("message");
            if (message != null) {
                return message.getString("content");
            }
        }
        LOG.warning("Unrecognized chat completion response: " + json);
        return null;
    }

    /** Opens an HTTP connection with the given method and this client's timeouts applied. */
    private static HttpURLConnection openConnection(URL url, String method) throws IOException {
        HttpURLConnection connection = (HttpURLConnection) url.openConnection();
        connection.setRequestMethod(method);
        connection.setConnectTimeout(CONNECT_TIMEOUT_MS);
        connection.setReadTimeout(READ_TIMEOUT_MS);
        return connection;
    }

    /**
     * Reads the whole response body as UTF-8, with line breaks dropped (as before), and closes
     * the stream.
     *
     * @throws IOException on transport failure, or for HTTP status >= 400; the message includes the
     *                     error body but only the URL path, so query-string secrets are not leaked
     */
    private static String readResponse(HttpURLConnection connection) throws IOException {
        int status = connection.getResponseCode();
        InputStream stream = status >= 400 ? connection.getErrorStream() : connection.getInputStream();
        String body = "";
        if (stream != null) {
            try (BufferedReader reader =
                         new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
                StringBuilder response = new StringBuilder();
                String line;
                while ((line = reader.readLine()) != null) {
                    response.append(line);
                }
                body = response.toString();
            }
        }
        if (status >= 400) {
            throw new IOException("HTTP " + status + " from " + connection.getURL().getPath() + ": " + body);
        }
        return body;
    }

    /** URL-encodes a query-parameter value as UTF-8; {@code null} becomes the empty string. */
    private static String urlEncode(String value) throws IOException {
        return value == null ? "" : URLEncoder.encode(value, "UTF-8");
    }
}
